﻿using System.Linq;
using System.Text;
using System.Collections.Generic;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Text;

namespace Bit.BlazorUI.SourceGenerators.Component;

/// <summary>
/// Source generator that, for every type owning component parameters collected by
/// <c>ComponentSyntaxContextReceiver</c>, emits a partial class containing a generated
/// <c>SetParametersAsync</c> override, a <c>HasNotBeenSet</c> helper and, for two-way bound
/// parameters, an <c>XxxChanged</c> callback parameter plus an <c>AssignXxx</c> method.
/// </summary>
[Generator]
public class ComponentSourceGenerator : ISourceGenerator
{
    /// <summary>Registers <c>ComponentSyntaxContextReceiver</c> to collect the parameters to generate code for.</summary>
    public void Initialize(GeneratorInitializationContext context)
    {
        context.RegisterForSyntaxNotifications(() => new ComponentSyntaxContextReceiver());
    }

    /// <summary>
    /// Groups the collected parameters by their containing type and adds one generated source file per type.
    /// </summary>
    public void Execute(GeneratorExecutionContext context)
    {
        if (context.SyntaxContextReceiver is not ComponentSyntaxContextReceiver receiver) return;

        var parametersByType = receiver.Parameters.GroupBy(p => p.PropertySymbol.ContainingType, SymbolEqualityComparer.Default);

        foreach (var parametersGroup in parametersByType)
        {
            // The type pattern also skips a null key.
            if (parametersGroup.Key is not INamedTypeSymbol classSymbol) continue;

            var classSource = GeneratePartialClass(classSymbol, parametersGroup.ToList());
            context.AddSource(GetHintName(classSymbol), SourceText.From(classSource, Encoding.UTF8));
        }
    }

    /// <summary>
    /// Builds the complete source text of the generated partial class for <paramref name="classSymbol"/>.
    /// </summary>
    /// <param name="classSymbol">The component type the parameters are declared on.</param>
    /// <param name="parameters">All collected parameters declared on that type.</param>
    /// <returns>The C# source of the generated partial class.</returns>
    private static string GeneratePartialClass(INamedTypeSymbol classSymbol, List<BlazorParameter> parameters)
    {
        // A type in the global namespace has no namespace to declare: ToDisplayString() on the global
        // namespace yields "<global namespace>", which is not valid C#, so the wrapper is omitted instead.
        var namespaceName = classSymbol.ContainingNamespace is { IsGlobalNamespace: false } containingNamespace
            ? containingNamespace.ToDisplayString()
            : null;
        var className = GetClassName(classSymbol);
        var twoWayParameters = parameters.Where(p => p.IsTwoWayBound).ToArray();

        var builder = new StringBuilder();
        builder.AppendLine("using System;");
        builder.AppendLine("using System.Threading.Tasks;");
        builder.AppendLine("using System.Collections.Generic;");
        builder.AppendLine("using Microsoft.AspNetCore.Components;");
        builder.AppendLine("using Microsoft.AspNetCore.Components.Web;");
        builder.AppendLine();
        if (namespaceName is not null)
        {
            builder.AppendLine($"namespace {namespaceName}");
            builder.AppendLine("{");
        }
        builder.AppendLine($"    public partial class {className}");
        builder.AppendLine("    {");

        AppendFields(builder, twoWayParameters);
        AppendSetParametersAsync(builder, classSymbol, parameters, twoWayParameters);
        builder.AppendLine();
        AppendHasNotBeenSet(builder);
        AppendAssignMethods(builder, twoWayParameters);

        builder.AppendLine("    }");
        if (namespaceName is not null)
        {
            builder.AppendLine("}");
        }

        return builder.ToString();
    }

    /// <summary>
    /// Emits the <c>__assignedParameters</c> tracking set and, per two-way bound parameter,
    /// its <c>XxxHasBeenSet</c> flag and <c>XxxChanged</c> callback parameter.
    /// </summary>
    private static void AppendFields(StringBuilder builder, BlazorParameter[] twoWayParameters)
    {
        builder.AppendLine("        private readonly HashSet<string> __assignedParameters = [];");
        builder.AppendLine();
        foreach (var par in twoWayParameters)
        {
            var sym = par.PropertySymbol;
            builder.AppendLine($"        private bool {sym.Name}HasBeenSet;");
            builder.AppendLine($"        [Parameter] public EventCallback<{sym.Type.ToDisplayString()}> {sym.Name}Changed {{ get; set; }}");
        }
        if (twoWayParameters.Length > 0) builder.AppendLine();
    }

    /// <summary>
    /// Emits the <c>SetParametersAsync</c> override. It assigns every known parameter from the
    /// incoming <c>ParameterView</c>, removes the handled entries from the dictionary and then calls the
    /// base implementation.
    /// </summary>
    private static void AppendSetParametersAsync(StringBuilder builder, INamedTypeSymbol classSymbol, List<BlazorParameter> parameters, BlazorParameter[] twoWayParameters)
    {
        var isBaseTypeComponentBase = classSymbol.BaseType?.ToDisplayString() == "Microsoft.AspNetCore.Components.ComponentBase";
        // Types deriving from BitComponentBase read from / fill the shared `ParametersCache`
        // (presumably declared on BitComponentBase; not visible here).
        var supportsParametersCache = InheritsFromBitComponentBase(classSymbol);

        builder.AppendLine("        [global::System.Diagnostics.DebuggerNonUserCode]");
        builder.AppendLine("        [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]");
        builder.AppendLine("        public override async Task SetParametersAsync(ParameterView parameters)");
        builder.AppendLine("        {");
        builder.AppendLine("            __assignedParameters.Clear();");
        foreach (var par in twoWayParameters)
        {
            builder.AppendLine($"            {par.PropertySymbol.Name}HasBeenSet = false;");
        }
        builder.AppendLine(supportsParametersCache
            ? "            var parametersDictionary = (ParametersCache ??= parameters.ToDictionary() as Dictionary<string, object>);"
            : "            var parametersDictionary = parameters.ToDictionary() as Dictionary<string, object>;");
        builder.AppendLine("            foreach (var parameter in parametersDictionary!)");
        builder.AppendLine("            {");
        builder.AppendLine("                switch (parameter.Key)");
        builder.AppendLine("                {");
        foreach (var par in parameters)
        {
            AppendParameterCase(builder, par);
        }
        builder.AppendLine("                }");
        builder.AppendLine("            }");
        // When the direct base is ComponentBase, or the parameters live in the shared cache, the base receives
        // an empty view. Otherwise the entries not handled above are forwarded to the base.
        builder.AppendLine(isBaseTypeComponentBase || supportsParametersCache
            ? "            await base.SetParametersAsync(ParameterView.Empty);"
            : "            await base.SetParametersAsync(ParameterView.FromDictionary(parametersDictionary as IDictionary<string, object?>));");
        builder.AppendLine("        }");
    }

    /// <summary>
    /// Emits the <c>case</c> section that assigns one parameter from the current dictionary entry and,
    /// for a two-way bound parameter, a second section assigning its <c>XxxChanged</c> callback.
    /// </summary>
    private static void AppendParameterCase(StringBuilder builder, BlazorParameter par)
    {
        var sym = par.PropertySymbol;
        var paramName = sym.Name;
        var paramType = sym.Type.ToDisplayString();
        var varName = GetLocalName(paramName);
        var hasChangeActions = par.ResetClassBuilder
            || par.ResetStyleBuilder
            || string.IsNullOrWhiteSpace(par.CallOnSetMethodName) is false
            || string.IsNullOrWhiteSpace(par.CallOnSetAsyncMethodName) is false;

        builder.AppendLine($"                    case nameof({paramName}):");
        builder.AppendLine($"                       __assignedParameters.Add(nameof({paramName}));");
        if (par.IsTwoWayBound)
        {
            builder.AppendLine($"                       {paramName}HasBeenSet = true;");
        }
        builder.AppendLine($"                       var {varName} = parameter.Value is null ? default! : ({paramType})parameter.Value;");
        if (hasChangeActions)
        {
            // The change is detected before the assignment so the follow-up actions run only on a real change.
            builder.AppendLine($"                       var notEquals{paramName} = EqualityComparer<{paramType}>.Default.Equals({paramName}, {varName}) is false;");
        }
        builder.AppendLine($"                       {paramName} = {varName};");
        if (par.ResetClassBuilder)
        {
            builder.AppendLine($"                       if (notEquals{paramName}) ClassBuilder.Reset();");
        }
        if (par.ResetStyleBuilder)
        {
            builder.AppendLine($"                       if (notEquals{paramName}) StyleBuilder.Reset();");
        }
        if (string.IsNullOrWhiteSpace(par.CallOnSetMethodName) is false)
        {
            builder.AppendLine($"                       if (notEquals{paramName}) {par.CallOnSetMethodName}();");
        }
        if (string.IsNullOrWhiteSpace(par.CallOnSetAsyncMethodName) is false)
        {
            builder.AppendLine($"                       if (notEquals{paramName}) await {par.CallOnSetAsyncMethodName}();");
        }
        builder.AppendLine("                       parametersDictionary.Remove(parameter.Key);");
        builder.AppendLine("                       break;");

        if (par.IsTwoWayBound is false) return;

        var changedName = $"{paramName}Changed";
        var changedVarName = GetLocalName(changedName);
        builder.AppendLine($"                    case nameof({changedName}):");
        builder.AppendLine($"                       var {changedVarName} = parameter.Value is null ? default! : (EventCallback<{paramType}>)parameter.Value;");
        builder.AppendLine($"                       {changedName} = {changedVarName};");
        builder.AppendLine("                       parametersDictionary.Remove(parameter.Key);");
        builder.AppendLine("                       break;");
    }

    /// <summary>
    /// Returns the name of the generated local that holds the incoming value for <paramref name="memberName"/>.
    /// </summary>
    /// <remarks>
    /// All switch sections share one declaration space, and the generated method already declares
    /// <c>parameters</c>, <c>parametersDictionary</c> and <c>parameter</c>. The previous scheme
    /// (<c>@</c> + lower-cased name) broke for members such as <c>Parameter</c>, and its culture-sensitive
    /// <c>ToLower()</c> made the output depend on the host culture. The reserved <c>__</c> prefix and
    /// fixed suffix make the name unique per member, never a keyword, and distinct from those locals.
    /// </remarks>
    private static string GetLocalName(string memberName) => $"__{memberName}Value";

    /// <summary>
    /// Emits <c>HasNotBeenSet(name)</c>. It returns true when <c>name</c> was not recorded in
    /// <c>__assignedParameters</c> during the last <c>SetParametersAsync</c> call.
    /// </summary>
    private static void AppendHasNotBeenSet(StringBuilder builder)
    {
        builder.AppendLine("        [global::System.Diagnostics.DebuggerNonUserCode]");
        builder.AppendLine("        [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]");
        builder.AppendLine("        public bool HasNotBeenSet(string name)");
        builder.AppendLine("        {");
        builder.AppendLine("            return __assignedParameters.Contains(name) is false;");
        builder.AppendLine("        }");
    }

    /// <summary>
    /// Emits one <c>AssignXxx(value)</c> method per two-way bound parameter. The method:
    /// <list type="bullet">
    /// <item>returns false without assigning when the parameter was supplied by the parent but no
    /// <c>XxxChanged</c> delegate is bound;</item>
    /// <item>otherwise assigns the value and raises <c>XxxChanged</c>, but only when the value differs;</item>
    /// <item>then runs the configured reset/callback actions and returns true.</item>
    /// </list>
    /// </summary>
    private static void AppendAssignMethods(StringBuilder builder, BlazorParameter[] twoWayParameters)
    {
        foreach (var par in twoWayParameters)
        {
            var paramName = par.PropertySymbol.Name;
            var paramType = par.PropertySymbol.Type.ToDisplayString();

            builder.AppendLine();
            builder.AppendLine("        [global::System.Diagnostics.DebuggerNonUserCode]");
            builder.AppendLine("        [global::System.Diagnostics.CodeAnalysis.ExcludeFromCodeCoverage]");
            builder.AppendLine($"        public async Task<bool> Assign{paramName}({paramType} value)");
            builder.AppendLine("        {");
            builder.AppendLine($"            if ({paramName}HasBeenSet && {paramName}Changed.HasDelegate is false) return false;");
            builder.AppendLine($"            if (EqualityComparer<{paramType}>.Default.Equals({paramName}, value) is false)");
            builder.AppendLine("            {");
            builder.AppendLine($"                {paramName} = value;");
            builder.AppendLine($"                await {paramName}Changed.InvokeAsync(value);");
            if (par.ResetClassBuilder)
            {
                builder.AppendLine("                ClassBuilder.Reset();");
            }
            if (par.ResetStyleBuilder)
            {
                builder.AppendLine("                StyleBuilder.Reset();");
            }
            if (string.IsNullOrWhiteSpace(par.CallOnSetMethodName) is false)
            {
                builder.AppendLine($"                {par.CallOnSetMethodName}();");
            }
            if (string.IsNullOrWhiteSpace(par.CallOnSetAsyncMethodName) is false)
            {
                builder.AppendLine($"                await {par.CallOnSetAsyncMethodName}();");
            }
            builder.AppendLine("            }");
            builder.AppendLine("            return true;");
            builder.AppendLine("        }");
        }
    }

    /// <summary>
    /// Builds a hint name that is unique per type. <c>AddSource</c> rejects duplicate hint names, so the
    /// simple name alone is not enough: it collided for same-named types in different namespaces, and for
    /// <c>Foo</c> vs <c>Foo&lt;T&gt;</c>. The namespace and the arity-bearing metadata name are therefore
    /// included. Any character other than a letter, digit, '.' or '_' is replaced with '_' so the name
    /// stays a valid file name.
    /// </summary>
    private static string GetHintName(INamedTypeSymbol classSymbol)
    {
        var qualifiedName = classSymbol.ContainingNamespace is { IsGlobalNamespace: false } containingNamespace
            ? $"{containingNamespace.ToDisplayString()}.{classSymbol.MetadataName}"
            : classSymbol.MetadataName;

        var sanitized = new StringBuilder(qualifiedName.Length);
        foreach (var c in qualifiedName)
        {
            sanitized.Append(char.IsLetterOrDigit(c) || c == '.' || c == '_' ? c : '_');
        }

        return $"{sanitized}_SetParametersAsync.AutoGenerated.cs";
    }

    /// <summary>
    /// Returns the type name as it appears in a declaration, e.g. <c>BitSelect&lt;TItem&gt;</c>.
    /// For a generic type definition the type arguments are its type parameters.
    /// </summary>
    private static string GetClassName(INamedTypeSymbol classSymbol)
    {
        if (classSymbol.IsGenericType is false) return classSymbol.Name;

        return $"{classSymbol.Name}<{string.Join(", ", classSymbol.TypeArguments.Select(s => s.Name))}>";
    }

    /// <summary>
    /// Returns true when <paramref name="typeSymbol"/> or one of its base classes is named
    /// <c>BitComponentBase</c>. The match is on the simple name only.
    /// </summary>
    private static bool InheritsFromBitComponentBase(INamedTypeSymbol? typeSymbol)
    {
        if (typeSymbol is null)
            return false;

        if (typeSymbol.TypeKind is not TypeKind.Class)
            return false;

        if (typeSymbol.Name == "BitComponentBase")
            return true;

        return InheritsFromBitComponentBase(typeSymbol.BaseType);
    }
}
